# client.py
import os
import sys
import math
import copy
import torch
import numpy as np
import torch.nn.functional
import torch.utils.data
import torch.optim.lr_scheduler as lr_scheduler
sys.path.append('extra_utils')
from extra_utils.model import resnet34, resnet101, mnist_Net
from extra_utils.distributed_utils import init_distributed_mode, cleanup, is_main_process
import torch.distributed as dist
from tqdm import tqdm


class Client(object):

    def __init__(self, args, conf, train_loader, train_sampler, eval_loader, rank):

        self.conf = conf

        self.rank = rank

        self.train_loader = train_loader

        self.eval_loader = eval_loader

        self.train_sampler = train_sampler

    def local_train(self, model, args, global_epoch, cost_list, train_length, accuracy_list,
                    accuracy_1, accuracy_2, rank, cost_1, cost_2):
        # train_length是方便我们每隔一段时间画出损失函数点

        # for name, param in model.state_dict().items():  # 遍历模型参数
        #     self.local_model.state_dict()[name].copy_(param.clone())  # 将全局模型复制一份到本地训练的模型
        local_model = copy.deepcopy(model)  # 拷贝拷贝可变类型就是完完全全拷贝了一份，是完完全全的两个内存，互不干扰

        # 是否冻结权重  我们默认不冻结
        if args.freeze_layers:
            for name, para in local_model.named_parameters():
                # 除最后的全连接层外，其他权重全部冻结
                if "fc" not in name:
                    para.requires_grad_(False)  # 即只训练全连接层
            pg = [p for p in local_model.parameters() if p.requires_grad]  # 将我们需要训练的各层参数(全连接层)以 列表生成式的方式 生成列表。
            optimizer = torch.optim.Adam(pg, lr=self.conf['lr'])  # 注意此时更新的不是model模型，而是pg列表所对应的模型参数
        else:
            optimizer = torch.optim.Adam(local_model.parameters(), lr=self.conf['lr'])
            # 若冻结权重，则带有BN结构的网络不会被训练，而训练带有BN结构的网络时使用SyncBatchNorm才有意义
            if args.syncBN:
                # 使用SyncBatchNorm后训练会更耗时
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
                # 将模型中所有的BN层替换成具有同步功能的BN层

        lf = lambda x: ((1 + math.cos(x * math.pi / self.conf["local_epochs"])) / 2) * (1 - args.lrf) + args.lrf  # cosine
        # 利用lambda函数定义一个输入参数x和函数lf的关系。

        # 使用随机梯度下降算法作为优化算法
        # 需要训练的全连接层参数、初始学习率、动量、正则项
        criterion = torch.nn.CrossEntropyLoss()
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

        local_model.train()  # 进入模型训练模式
        for epoch in range(self.conf["local_epochs"]):

            self.train_sampler.set_epoch(epoch)  # 这种方法使得我们的各个设备在每一轮所获得的数据都是不一样的。
            if is_main_process():
                self.train_loader = tqdm(self.train_loader, file=sys.stdout)

            total_loss = 0
            running_loss = 0
            local_loss = 0

            for batch_id, batch in enumerate(self.train_loader):
                inputs, target = batch
                if rank == 2:
                    for j in range(len(target)):  # 标签翻转攻击
                        if target[j] == 1:
                            target[j] = 4

                optimizer.zero_grad()  # 梯度清零
                output = local_model(inputs)  # 前向传播算法
                loss = criterion(output, target)  # 使用交叉熵计算损失值
                loss.backward()  # 反向传播计算得到了梯度
                total_loss += loss.item()
                running_loss += loss.item()

                optimizer.step()  # 利用反向传播得到的梯度，利用优化算法更新网络参数（权重）
                scheduler.step()  # 更新学习率

                if (batch_id + 1) % (train_length / (self.conf["batch_size"]*5*self.conf["world_size"])) == 0:
                    # 每次本地训练共取出5个点来绘制。
                    cost_list.append(running_loss)
                    # print(running_loss)
                    # 使用命令行参数控制是否打印其他进程图线
                    if not args.only_0:
                        # 用于绘制除rank0之外的其他进程的图线的Loss代码。
                        loss_1 = torch.FloatTensor([running_loss])  # 方便发送接收rank1的准确率
                        loss_2 = torch.FloatTensor([running_loss])  # 方便发送接收rank2的准确率
                        if is_main_process():
                            dist.recv(loss_1, src=1)  # 接收来源于rank1的数据，并将其覆盖于loss_1
                            dist.recv(loss_2, src=2)  # 接收来源于rank2的数据，并将其覆盖于loss_2
                        elif rank == 1:
                            dist.send(loss_1, dst=0)  # 如果是rank1，发送数据acc1
                        else:
                            dist.send(loss_2, dst=0)  # 如果是rank2，发送数据acc2
                        dist.barrier()
                        loss_1 = loss_1.item()  # 取出tensor的data数据。避免最后plt.plot出错如下。
                        loss_2 = loss_2.item()  # VisibleDeprecationWarning: Creating an ndarray from ragged nested...
                        if is_main_process():
                            cost_1.append(loss_1)
                            cost_2.append(loss_2)
                    running_loss = 0

                    # 我们要对所有的进程执行此操作，否则会引起每个进程微小的误差。
                    # 此部分会极大地拉慢程序运行速度。
                    with torch.no_grad():
                        local_model.eval()  # 进入模型评估模式
                        correct = 0
                        dataset_size = 0
                        for batch_id, batch in enumerate(self.eval_loader):
                            inputs, target = batch
                            dataset_size += inputs.size()[0]
                            output = local_model(inputs)
                            local_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
                            pred = output.data.max(1)[1]
                            correct += pred.eq(target.data.view_as(pred)).sum().item()
                        acc = 100.0 * (float(correct) / float(dataset_size))  # 准确率
                        accuracy_list.append(acc)
                        # print(acc)
                        # 使用命令行参数控制是否打印其他进程图线
                        if not args.only_0:
                            # 用于绘制除rank0之外的其他进程的图线的准确率代码。
                            acc1 = torch.FloatTensor([acc])  # 方便发送接收rank1的准确率
                            acc2 = torch.FloatTensor([acc])  # 方便发送接收rank2的准确率
                            if is_main_process():
                                dist.recv(acc1, src=1)  # 接收来源于rank1的数据，并将其覆盖于acc1
                                dist.recv(acc2, src=2)  # 接收来源于rank2的数据，并将其覆盖于acc2
                            elif rank == 1:
                                dist.send(acc1, dst=0)  # 如果是rank1，发送数据acc1
                            else:
                                dist.send(acc2, dst=0)  # 如果是rank2，发送数据acc2
                            dist.barrier()
                            acc1 = acc1.item()  # 取出tensor的data数据。避免最后plt.plot出错如下。
                            acc2 = acc2.item()  # VisibleDeprecationWarning: Creating an ndarray from ragged nested...
                            if is_main_process():
                                accuracy_1.append(acc1)
                                accuracy_2.append(acc2)
                                # print(acc)
                                # print(acc1)
                                # print(acc2)

                    local_loss = 0

            dist.barrier()  # 防止打印进度条的时候，会由于其他进程先训练完成而导致打印出来的字符串混乱掉进度条。

            if is_main_process():
                print("Rank %d, Global_epoch [%d/%d], Local Epoch [%d/%d] loss : %f."
                      % (self.rank, global_epoch + 1, self.conf["global_epochs"], epoch + 1,
                         self.conf["local_epochs"], total_loss))
            # 各进程loss大小差距不大，个人认为应该是因为迁移学习所造成的

        # 取出权重矩阵
        weight_matrix = []
        bias_matrix = []
        for name, parm in local_model.named_parameters():
            if (name == 'module.fc3.weight') | (name == 'fc3.weight'):  # 可以将最后一层全连接层的权重矩阵取出
                weight_matrix = parm.detach()
            elif (name == 'module.fc3.bias') | (name == 'fc3.bias'):
                bias_matrix = parm.detach()
                bias_matrix = bias_matrix.unsqueeze(1)
                if is_main_process():
                    print(bias_matrix.size())
                # tensor.detach()返回一个新的tensor，从当前计算图中分离下来的，但是仍指向原变量的存放位置,
                # 不同之处只是requires_grad为false，得到的这个tensor永远不需要计算其梯度，不具有grad。
        mean = torch.ones([weight_matrix.size()[1], 1])
        mean = mean / weight_matrix.size()[1]
        class_mean = weight_matrix.mm(mean)  # 利用矩阵乘法求得 此客户端下 该轮次的 类平均值torch.Size([10, 1])
        class_mean = class_mean + bias_matrix / weight_matrix.size()[1]

        diff = dict()  # 生成一个空的字典
        for name, data in local_model.state_dict().items():  # 遍历更新之后的各层模型参数。并返回每层对应的名字(name)和数据。
            # print(data != model.state_dict()[name])  # 用于打印出来是否参数相等
            diff[name] = (data - model.state_dict()[name])  # 将当前name和全局模型所对应name的数据进行相减，得到权重大小的变化量即权重差
        # print(diff[name])

        return diff, class_mean  # 返回网络参数的变化.value为tensor类型

    # 模型评估
    @torch.no_grad()  # 装饰器的方法实现with.no_grad()
    # 我们模型评估的时候使用全部测试集
    def model_eval(self, model):
        model.eval()  # 进入模型评估模式

        total_loss = 0.0
        correct = 0
        dataset_size = 0
        if is_main_process():
            self.eval_loader = tqdm(self.eval_loader, file=sys.stdout)

        for batch_id, batch in enumerate(self.eval_loader):  # batch_id就为enumerate()遍历集合所返回的批量序号
            inputs, target = batch  # 得到数据集和标签
            dataset_size += inputs.size()[0]  # data.size()=[batch,通道数,32,32]、target.size()=[batch]

            output = model(inputs)

            if self.conf["type"] == "mnist":
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            elif self.conf["type"] == "flower":
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            elif self.conf["type"] == "cifar":
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
            else:
                raise TypeError("Not find Appropriate mode.")
                # sum up batch loss
            # .data意即将变量的tensor取出来
            # 因为tensor包含data和grad，分别放置数据和计算的梯度
            pred = output.data.max(1)[1]  # get the index of the max log-probability
            # 按照从左往右的 第一维 取出最大值的索引 torch.max()
            correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()
        # torch.view_as(tensor)即将调用函数的变量，转变为同参数tensor同样的形状
        # torch.eq()对两个张量tensor进行逐元素比较，如果相等则返回True，否则返回False。True和False作运算时可以作1、0使用
        # .cpu()这一步将预测结果放到cpu上，利用电脑内存存储列表值。从而避免测试过程中爆显存。
        # .sum()是将我们一个批量的预测值求和，便于累加到correct变量中。
        # .item()取出 单元素张量的元素值 并返回该值，保持原元素类型不变。

        acc = 100.0 * (float(correct) / float(dataset_size))  # 准确率
        aver_loss = total_loss / dataset_size  # 平均损失

        return acc, aver_loss
